package drds.plus.sql_process.optimizer.execute_plan_optimizer;

import drds.plus.common.jdbc.Parameters;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.ExecutePlan;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.Join;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.MergeQuery;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.Query;
import drds.plus.sql_process.abstract_syntax_tree.execute_plan.query.QueryWithIndex;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.Item;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.Function;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.FunctionType;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Execute-plan optimizer that rewrites {@code avg(...)} select items into a
 * {@code count(...)} item plus a {@code sum(...)} item. The rewrite changes the
 * column structure (item list) of the affected queries.
 *
 * <p>The rewrite is applied only to the direct children of a {@link MergeQuery}
 * that has more than one child; joins and sub-queries are merely traversed.
 */
public class AvgOptimizer implements ExecutePlanOptimizer {

    /**
     * Column-name prefix by which an avg function item is recognised.
     */
    private static final String AVG_PREFIX = "avg(";

    public AvgOptimizer() {
    }

    /**
     * Walks the plan tree and replaces avg items with count and sum items where needed.
     *
     * <ul>
     * <li>{@link MergeQuery} with more than one child: every child first has its avg items
     * expanded, then every child is optimized recursively.</li>
     * <li>{@link Join}: left and right nodes are optimized recursively; the join itself is
     * not expanded here.</li>
     * <li>{@link QueryWithIndex} that is a sub-query: its sub-query is optimized recursively.</li>
     * <li>Anything else is left untouched.</li>
     * </ul>
     *
     * @param executePlan plan node to process; its item lists may be modified in place
     * @param parameters  passed through unchanged to the recursive calls
     * @param extraCmd    passed through unchanged to the recursive calls
     * @return the same {@code executePlan} instance that was passed in
     */
    public ExecutePlan optimize(ExecutePlan executePlan, Parameters parameters, Map<String, Object> extraCmd) {
        if (executePlan instanceof MergeQuery && ((MergeQuery) executePlan).getExecutePlanList().size() > 1) {
            MergeQuery mergeQuery = (MergeQuery) executePlan;
            // First pass: expand avg on every direct child of the merge.
            for (ExecutePlan child : mergeQuery.getExecutePlanList()) {
                expendAvgFunction(child);
            }
            // Second pass: descend into each child.
            for (ExecutePlan child : mergeQuery.getExecutePlanList()) {
                this.optimize(child, parameters, extraCmd);
            }
        } else if (executePlan instanceof Join) {
            Join join = (Join) executePlan;
            // Original author's note: a join is executed in "map" mode, so avg does not
            // need to be expanded here; only the child nodes are processed recursively.
            this.optimize(join.getLeftNode(), parameters, extraCmd);
            this.optimize(join.getRightNode(), parameters, extraCmd);
        } else if (executePlan instanceof QueryWithIndex) {
            QueryWithIndex queryWithIndex = (QueryWithIndex) executePlan;
            // Original author's note: a sub-query is executed in "map" mode, so avg does not
            // need to be expanded here; only the nested sub-query is processed recursively.
            if (queryWithIndex.isSubQuery()) {
                this.optimize(queryWithIndex.getSubQuery(), parameters, extraCmd);
            }
        }

        return executePlan;
    }

    /**
     * Tells whether any direct argument of {@code function} is itself a function whose
     * column name starts with {@code "avg("}. Only the first argument level is inspected.
     *
     * @param function function whose argument list is scanned
     * @return {@code true} if at least one direct argument is an avg function
     */
    private boolean functionArgsHasAvgFunction(Function function) {
        for (Object arg : function.getArgList()) {
            if (arg instanceof Function && ((Function) arg).getColumnName().startsWith(AVG_PREFIX)) {
                return true;
            }
        }

        return false;
    }

    /**
     * Expands avg function items of the given plan into sum/count items.
     *
     * <p>Only {@link QueryWithIndex} and {@link Join} plans are handled. For every item that is
     * a function whose column name starts with {@code "avg("}, the item is removed and a
     * {@code count} copy and a {@code sum} copy are appended to the end of the item list
     * (count before sum). Scalar functions that have an avg function as a direct argument
     * are removed without replacement.
     *
     * @param executePlan plan whose item list is modified in place
     */
    private void expendAvgFunction(ExecutePlan executePlan) {
        if (!(executePlan instanceof QueryWithIndex || executePlan instanceof Join)) {
            return;
        }

        Query query = (Query) executePlan;
        List<Item> add = new ArrayList<Item>();
        List<Item> remove = new ArrayList<Item>();
        for (Object object : query.getItemList()) {
            Item item = (Item) object;
            if (!(item instanceof Function)) {
                continue;
            }

            if (item.getColumnName().startsWith(AVG_PREFIX)) {
                // Alias suffixes "1" (sum) and "2" (count) keep the two derived aliases distinct.
                Function sum = deriveAggregate(item, "sum", "1");
                Function count = deriveAggregate(item, "count", "2");

                add.add(count);
                add.add(sum);

                remove.add(item);
            } else {
                // Drop scalar functions built on top of AVG, e.g. 1 + avg(ID).
                // For now these can only be computed by the upper layer.
                // Possible risk: Function computations that are not supported yet.
                Function function = (Function) item;
                if (FunctionType.scalar.equals(function.getFunctionType()) && functionArgsHasAvgFunction(function)) {
                    remove.add(item);
                }
            }
        }

        if (!remove.isEmpty()) {
            query.getItemList().removeAll(remove);
            query.getItemList().addAll(add);
        }
    }

    /**
     * Builds a copy of an avg item that is turned into another aggregate function.
     *
     * <p>The copy gets its extra function cleared, its function name set to
     * {@code functionName}, every {@code "avg("} in the column name replaced by
     * {@code functionName + "("}, and, when an alias is present, {@code aliasSuffix}
     * appended to that alias.
     *
     * @param avgItem      the avg item to copy; it is not modified by this method's setters
     * @param functionName name of the replacement function, e.g. {@code "sum"}
     * @param aliasSuffix  suffix appended to a non-null alias
     * @return the derived function item
     */
    private Function deriveAggregate(Item avgItem, String functionName, String aliasSuffix) {
        Function derived = (Function) avgItem.copy();
        derived.setExtraFunction(null);
        derived.setFunctionName(functionName);
        derived.setColumnName(avgItem.getColumnName().replace(AVG_PREFIX, functionName + "("));
        if (derived.getAlias() != null) {
            derived.setAlias(derived.getAlias() + aliasSuffix);
        }
        return derived;
    }
}
